# Copyright 2021-2023 @ Shenzhen Bay Laboratory &
#                       Peking University &
#                       Huawei Technologies Co., Ltd
#
# This code is a part of MindSPONGE:
# MindSpore Simulation Package tOwards Next Generation molecular modelling.
#
# MindSPONGE is open-source software based on the AI-framework:
# MindSpore (https://www.mindspore.cn/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Controller
"""

from typing import Union, Tuple, List
from numpy import ndarray
import torch
from torch import Tensor
from torch.nn import Module

from ..system import Molecule
from ..function import functions as func
from ..function.functions import get_integer, get_ms_array, get_arguments, GLOBAL_DEVICE_SPONGE


class Controller(Module):
    r"""
    Base class for the controller module in MindSPONGE.
    The `Controller` is used in `Updater` to control the values of seven variables during the simulation
    process: coordinate, velocity, force, energy, kinetics, virial and pbc_box.

    Args:
        system(Molecule):   Simulation system
        control_step(int):  Step interval for controller execution. Must be larger than 0. Default: 1
        **kwargs:           Additional keyword arguments, recorded together with the explicit
                            arguments in ``self._kwargs``.

    Inputs:
        - **coordinate** (Tensor) - Tensor of shape `(B, A, D)`. Data type is float.
        - **velocity** (Tensor) - Tensor of shape `(B, A, D)`. Data type is float.
        - **force** (Tensor) - Tensor of shape `(B, A, D)`. Data type is float.
        - **energy** (Tensor) - Tensor of shape `(B, 1)`. Data type is float.
        - **kinetics** (Tensor) - Tensor of shape `(B, D)`. Data type is float.
        - **virial** (Tensor) - Tensor of shape `(B, D)`. Data type is float.
        - **pbc_box** (Tensor) - Tensor of shape `(B, D)`. Data type is float.
        - **step** (int) - Simulation step. Default: 0

    Outputs:
        - coordinate, Tensor of shape `(B, A, D)`. Data type is float.
        - velocity, Tensor of shape `(B, A, D)`. Data type is float.
        - force, Tensor of shape `(B, A, D)`. Data type is float.
        - energy, Tensor of shape `(B, 1)`. Data type is float.
        - kinetics, Tensor of shape `(B, D)`. Data type is float.
        - virial, Tensor of shape `(B, D)`. Data type is float.
        - pbc_box, Tensor of shape `(B, D)`. Data type is float.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """
    def __init__(self,
                 system: Molecule,
                 control_step: int = 1,
                 **kwargs,
                 ):

        super().__init__()
        # NOTE: must stay the first statement after super().__init__(), because
        # locals() is captured here and any earlier local would be recorded too.
        self._kwargs = get_arguments(locals(), kwargs)

        self.device = GLOBAL_DEVICE_SPONGE()
        self.system = system
        self.num_walker = self.system.num_walker
        self.num_atoms = system.num_atoms
        self.dimension = system.dimension

        self.sys_dofs = system.degrees_of_freedom
        self.degrees_of_freedom = system.degrees_of_freedom

        self.time_step = get_ms_array(1e-3, dtype=torch.float32, device=self.device)

        self._coordinate = self.system.coordinate.to(self.device)
        self._pbc_box = self.system.pbc_box.to(self.device) if self.system.pbc_box is not None else None

        self.units = self.system.units
        self.kinetic_unit_scale = get_ms_array(self.units.kinetic_ref, dtype=torch.float32, device=self.device)
        self.press_unit_scale = get_ms_array(self.units.pressure_ref, dtype=torch.float32, device=self.device)

        # (B, A)
        self.atom_mass = self.system.atom_mass.to(self.device)
        self.inv_mass = self.system.inv_mass.to(self.device)
        # (B, A, 1): trailing axis added so masses broadcast against (B, A, D) tensors
        self._atom_mass = self.atom_mass.unsqueeze(-1)
        self._inv_mass = self.inv_mass.unsqueeze(-1)

        # (B, 1)
        self.system_mass = self.system.system_mass.to(self.device)
        self.system_natom = self.system.system_natom

        self.control_step = get_integer(control_step)
        if self.control_step <= 0:
            raise ValueError('The "control_step" must be larger than 0!')

        self.num_constraints = 0

    @property
    def boltzmann(self) -> float:
        """Boltzmann constant taken from the system's units object."""
        return self.units.boltzmann

    def set_time_step(self, dt: float):
        """Set the time step used by the controller.

        Args:
            dt(float):  New time step value.

        Returns:
            Controller, ``self`` (allows call chaining).
        """
        # Create the tensor on the controller's device, exactly as __init__ does;
        # otherwise the time step may live on a different device from the
        # other controller tensors.
        self.time_step = get_ms_array(dt, dtype=torch.float32, device=self.device)
        return self

    def set_degrees_of_freedom(self, dofs: int):
        """Set the degrees of freedom used by :meth:`get_temperature`.

        Args:
            dofs(int):  Number of degrees of freedom.

        Returns:
            Controller, ``self`` (allows call chaining).
        """
        self.degrees_of_freedom = get_integer(dofs)
        return self

    def update_coordinate(self, coordinate: Tensor) -> Tensor:
        """Copy ``coordinate`` in place into the stored coordinate tensor and return it."""
        return self._coordinate.data.copy_(coordinate)

    def update_pbc_box(self, pbc_box: Tensor) -> Tensor:
        """Copy ``pbc_box`` in place into the stored PBC box tensor and return it.

        If the system has no PBC box, ``pbc_box`` is returned unchanged.
        """
        if self._pbc_box is None:
            return pbc_box
        return self._pbc_box.data.copy_(pbc_box)

    def get_kinetics(self, velocity: Tensor) -> Tensor:
        r"""
        Calculate kinetics according to velocity.

        Args:
            velocity(Tensor):   Tensor of atomic velocities. Tensor shape is `(B, A, D)`.
                                Data type is float.

        Returns:
            Tensor, Tensor of kinetics. Tensor shape is `(B, D)`. Data type is float.
            Returns None if ``velocity`` is None.
        """
        if velocity is None:
            return None
        # (B, A, D) * (B, A, 1)
        k = 0.5 * self._atom_mass * velocity**2
        # (B, D) <- (B, A, D)
        kinetics = torch.sum(k, -2)
        return kinetics * self.kinetic_unit_scale

    def get_temperature(self, kinetics: Tensor = None) -> Tensor:
        r"""
        Calculate temperature according to kinetics.

        Args:
            kinetics(Tensor):   Tensor of kinetics. Tensor shape is `(B, D)`. Data type is float. Default: None

        Returns:
            Tensor, Tensor of temperature. The shape of the Tensor is `(B)`. Data type is float.
            Returns None if ``kinetics`` is None.
        """
        if kinetics is None:
            return None
        # (B) <- (B, D)
        kinetics = torch.sum(kinetics, -1)
        return 2 * kinetics / self.degrees_of_freedom / self.boltzmann

    def get_volume(self, pbc_box: Tensor) -> Tensor:
        """Calculate the box volume as the product of the box lengths.

        Args:
            pbc_box(Tensor):    Tensor of shape `(B, D)`.

        Returns:
            Tensor of shape `(B, 1)`, or None if the system has no PBC box.
        """
        if self._pbc_box is None:
            return None
        # (B, 1) <- (B, D)
        return torch.prod(pbc_box, -1, True)

    def get_pressure(self, kinetics: Tensor, virial: Tensor, pbc_box: Tensor) -> Tensor:
        """Calculate the pressure tensor as ``2 * (kinetics - virial) / volume``.

        Args:
            kinetics(Tensor):   Tensor of shape `(B, D)`.
            virial(Tensor):     Tensor of shape `(B, D)`.
            pbc_box(Tensor):    Tensor of shape `(B, D)`.

        Returns:
            Tensor of shape `(B, D)` scaled by the pressure unit factor, or None if
            the system has no PBC box.
        """
        if self._pbc_box is None:
            return None
        volume = torch.prod(pbc_box, -1, True)
        # (B, D) = ((B, D) - (B, D)) / (B, 1)
        pressure = 2 * (kinetics - virial) / volume
        return pressure * self.press_unit_scale

    def _mass_weighted_mean(self, values: Tensor, keepdims: bool) -> Tensor:
        """Mass-weighted mean of per-atom ``values`` over the atom axis (-2)."""
        # (B, A, D) = (B, A, D) * (B, A, 1)
        weighted = values * self._atom_mass
        if keepdims:
            # (B, 1, D) = (B, 1, D) / (B, 1, 1)
            return torch.sum(weighted, -2, keepdim=True) / self.system_mass.unsqueeze(-1)
        # (B, D) = (B, D) / (B, 1)
        return torch.sum(weighted, -2) / self.system_mass

    def get_com(self, coordinate: Tensor, keepdims: bool = True) -> Tensor:
        """Calculate the center of mass.

        Args:
            coordinate(Tensor): Tensor of shape `(B, A, D)`.
            keepdims(bool):     Keep the atom axis in the result. Default: True

        Returns:
            Tensor of shape `(B, 1, D)` if ``keepdims`` else `(B, D)`.
        """
        return self._mass_weighted_mean(coordinate, keepdims)

    def get_com_velocity(self, velocity: Tensor, keepdims: bool = True) -> Tensor:
        """Calculate the velocity of the center of mass.

        Args:
            velocity(Tensor):   Tensor of shape `(B, A, D)`.
            keepdims(bool):     Keep the atom axis in the result. Default: True

        Returns:
            Tensor of shape `(B, 1, D)` if ``keepdims`` else `(B, D)`.
        """
        return self._mass_weighted_mean(velocity, keepdims)

    def forward(self,
                coordinate: Tensor,
                velocity: Tensor,
                force: Tensor,
                energy: Tensor,
                kinetics: Tensor,
                virial: Tensor = None,
                pbc_box: Tensor = None,
                step: int = 0,
                ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Base implementation: return all inputs (except ``step``) unchanged."""
        #pylint: disable=unused-argument

        return coordinate, velocity, force, energy, kinetics, virial, pbc_box

    def _get_mw_tensor(self,
                       value: Union[float, ndarray, Tensor, List[float]],
                       name: str) -> Tensor:
        """Get a float32 tensor for multiple walkers.

        Args:
            value:      Scalar or per-walker values.
            name(str):  Name of the quantity, used in the error message.

        Returns:
            A 0-d tensor if ``value`` has one element, otherwise a tensor of
            shape `(B, 1)`.

        Raises:
            ValueError: If the number of elements is neither 1 nor ``num_walker``.
        """
        value = func.get_tensor(value, torch.float32, device=self.device)
        # Use the total element count. If a scalar input becomes a 0-d tensor,
        # ``value.size(0)`` would raise IndexError.
        num_values = value.numel()
        if num_values == 1:
            # ()
            return torch.reshape(value, ())
        if num_values != self.num_walker:
            raise ValueError(f'The size of {name} must be equal to 1 or '
                             f'the number of multiple walkers ({self.num_walker}) but got: {num_values}')
        # (B, 1)
        return torch.reshape(value, (self.num_walker, 1))
